#!/usr/bin/env python3
"""
Extract and verify the N‑dependence of the one‑loop gauge β‑function.

This driver script extends the single‑N RG extraction to multiple
SU(N) gauge groups.  For each N in a predefined list, it samples the
running coupling g(μ) at several scales μ using a simple one‑loop
running model, computes the numerical β‑function via finite
differences, fits the data to the form β(g) = −β₁/(16 π²)·g³ and
records the fitted β₁.  A summary of β₁ as a function of N is saved
and plotted together with the theoretical expectation β₁ = 11 N/3.

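All results (CSVs and figures) are written to the ``results/`` directory
of ``N_scaling``.  Its location is resolved from this script's own path
(``<N_scaling>/scripts/extract_beta_N.py`` → ``<N_scaling>/results``),
not from the current working directory, and the directory is created if
it does not already exist.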

Usage::

    python scripts/extract_beta_N.py

Run from the ``N_scaling`` directory as shown above, the script writes
the raw data (CSV) and plots (PNG) into ``results/`` and prints the
fitted β₁ values and their relative errors to standard output.
"""

import csv
import math
from pathlib import Path
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


def ensure_results_dir() -> Path:
    """Create the ``results`` directory if needed and return its path.

    The directory is located relative to this file rather than the current
    working directory: it is ``results`` inside the grandparent of this
    script (``<N_scaling>/results`` when the script sits in
    ``<N_scaling>/scripts/``, as the usage line assumes).
    """
    script_path = Path(__file__).resolve()
    target = script_path.parents[1] / "results"
    # parents=True / exist_ok=True make repeated runs a no-op
    target.mkdir(parents=True, exist_ok=True)
    return target


def save_csv(filepath: Path, rows: List[Tuple]):
    """Write ``rows`` to ``filepath`` as CSV, one record per line, no header.

    Any existing file at ``filepath`` is overwritten.  ``newline=""`` is
    required by the ``csv`` module so it controls line endings itself.
    """
    with open(filepath, mode="w", newline="") as handle:
        csv.writer(handle).writerows(rows)


def run_rg_sampler(mu: float, N: int, *, g0: float = 1.0,
                   mu0: float = 1.0, noise_std: float = 0.005) -> float:
    """
    Simulate the running coupling g(μ) for a pure SU(N) gauge theory.

    The function uses the analytic solution of the one‑loop RG equation,

        1/g²(μ) = 1/g₀² + β₁/(8 π²)·ln(μ/μ₀),   β₁ = 11 N/3,

    and multiplies the result by (1 + ε) with ε ~ Normal(0, noise_std) to
    mimic Monte Carlo fluctuations.  The noise is drawn from NumPy's global
    random state, so the caller must seed it (``np.random.seed``) to make
    results reproducible.

    Parameters
    ----------
    mu : float
        Energy scale (must be positive).
    N : int
        The number of colours of the SU(N) gauge group.
    g0 : float, optional
        Reference coupling at scale μ₀ (default 1.0).
    mu0 : float, optional
        Reference energy scale (default 1.0).
    noise_std : float, optional
        Standard deviation of relative noise applied to g (default 0.5 %).

    Returns
    -------
    float
        The running coupling g(μ) with noise.

    Raises
    ------
    ValueError
        If ``mu / mu0`` is not positive (raised by ``math.log``), or if μ
        lies at or below the one‑loop Λ scale where 1/g² ≤ 0 and the
        coupling is undefined.
    """
    beta1 = 11 * N / 3
    # coefficient in 1/g² solution
    b = beta1 / (8 * math.pi ** 2)
    inv_g2 = 1.0 / (g0 ** 2) + b * math.log(mu / mu0)
    if inv_g2 <= 0.0:
        # Without this check a negative inv_g2 ** (-0.5) silently yields a
        # complex number (and zero raises ZeroDivisionError).  Checking
        # before the noise draw leaves the RNG sequence of valid calls
        # untouched.
        raise ValueError(
            f"one-loop coupling undefined at mu={mu!r} for N={N!r} "
            f"(1/g^2 = {inv_g2!r} <= 0; mu is at or below the Lambda scale)"
        )
    g_mu = inv_g2 ** (-0.5)
    # multiplicative noise
    noise = np.random.normal(loc=0.0, scale=noise_std)
    return g_mu * (1 + noise)


def one_loop_beta(g: float, beta1: float) -> float:
    """Return the one‑loop gauge β‑function, β(g) = −β₁ g³ / (16 π²).

    Only arithmetic operators are used, so ``g`` may also be a NumPy array,
    in which case the result is evaluated elementwise.
    """
    prefactor = beta1 / (16 * math.pi ** 2)
    return -prefactor * g ** 3


def _sorted_series(arr: np.ndarray, N: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return the (μ, g) columns of the rows of ``arr`` with N, sorted by μ.

    ``arr`` holds rows of (N, μ, g) as loaded from ``g_vs_mu_N.csv``.
    """
    sub = arr[arr[:, 0] == N]
    order = np.argsort(sub[:, 1])
    return sub[order, 1], sub[order, 2]


def _finite_difference_beta(
    mus: np.ndarray, gs: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Estimate β = dg/d ln μ between consecutive (μ, g) samples.

    Returns ``(mu_mid, g_mid, beta)``, each of length ``len(mus) - 1``.
    The difference quotient between samples i and i+1 approximates the
    derivative at the midpoint in ln μ to second order.  It is therefore
    paired with the geometric‑mean scale and the mean coupling of the two
    samples.  Pairing it with either endpoint would bias a β(g) fit at
    first order in the step Δ ln μ.
    """
    log_mu = np.log(mus)
    beta = np.diff(gs) / np.diff(log_mu)
    mu_mid = np.exp(0.5 * (log_mu[:-1] + log_mu[1:]))
    g_mid = 0.5 * (gs[:-1] + gs[1:])
    return mu_mid, g_mid, beta


def main():
    """Run the multi‑N β‑function extraction and write all outputs.

    Steps:

    1. Sample g(μ) for each N and save ``g_vs_mu_N.csv`` with rows (N, μ, g).
    2. Estimate β numerically and save ``beta_numeric_N.csv`` with rows
       (N, μ_mid, β_numeric).
    3. Fit β₁ for each N and save ``beta1_vs_N.csv`` with rows (N, β₁_fit).
    4. Save the plots ``beta1_scaling.png`` and ``g_vs_lnmu_N.png``.
    5. Print the fitted β₁ against the theory value 11N/3.

    All files go to the directory returned by ``ensure_results_dir``.

    Raises
    ------
    ValueError
        If some N has fewer than two sampled scales, so that no derivative
        can be formed.
    """
    # deterministic random seed for reproducibility (run_rg_sampler draws
    # its noise from NumPy's global random state)
    np.random.seed(42)

    # gauge groups to test and energy scales
    Ns = [2, 4]
    mus = [8, 16, 32, 64]

    # sample g(μ); the order (N outer, μ inner) fixes the RNG sequence
    g_data: List[Tuple[int, float, float]] = [
        (N, mu, run_rg_sampler(mu, N)) for N in Ns for mu in mus
    ]

    # save raw g vs mu data
    results_dir = ensure_results_dir()
    g_vs_mu_N_path = results_dir / "g_vs_mu_N.csv"
    save_csv(g_vs_mu_N_path, g_data)

    # analyse exactly what was written to disk; ndmin=2 keeps a 2-D shape
    # even if the file holds a single row
    arr = np.loadtxt(g_vs_mu_N_path, delimiter=",", ndmin=2)

    beta_numeric_rows: List[Tuple[int, float, float]] = []  # (N, mu_mid, beta_numeric)
    beta1_fit_rows: List[Tuple[int, float]] = []  # (N, beta1_fit)
    for N in Ns:
        mus_n, gs_n = _sorted_series(arr, N)
        if len(mus_n) < 2:
            raise ValueError(
                f"need at least two scales to estimate beta for N={N}, "
                f"got {len(mus_n)}"
            )
        mu_mid, g_mid, beta_numeric = _finite_difference_beta(mus_n, gs_n)
        beta_numeric_rows.extend(
            (int(N), float(m), float(b)) for m, b in zip(mu_mid, beta_numeric)
        )
        # the model is linear in β₁, so curve_fit converges from its default
        # starting value; fit against the midpoint couplings (see helper)
        popt, _ = curve_fit(one_loop_beta, g_mid, beta_numeric)
        beta1_fit_rows.append((int(N), float(popt[0])))

    # save β_numeric data
    beta_numeric_path = results_dir / "beta_numeric_N.csv"
    save_csv(beta_numeric_path, beta_numeric_rows)

    # save β₁_fit vs N
    beta1_vs_N_path = results_dir / "beta1_vs_N.csv"
    save_csv(beta1_vs_N_path, beta1_fit_rows)

    # create plot of β₁_fit vs N with theoretical line
    Ns_arr = np.array([row[0] for row in beta1_fit_rows], dtype=float)
    beta1_fit_arr = np.array([row[1] for row in beta1_fit_rows], dtype=float)
    # theoretical β₁ = 11N/3
    beta1_theory_arr = 11 * Ns_arr / 3
    plt.figure()
    plt.plot(Ns_arr, beta1_fit_arr, 'o', label=r"Fitted $\beta_1$")
    # theory is linear in N, so joining the points draws the exact line
    plt.plot(Ns_arr, beta1_theory_arr, 'r--', label=r"Theory $\beta_1=11N/3$")
    plt.xlabel("N")
    plt.ylabel(r"$\beta_1$")
    plt.title(r"One‑loop coefficient $\beta_1$ vs N")
    plt.legend()
    plt.tight_layout()
    beta1_plot_path = results_dir / "beta1_scaling.png"
    plt.savefig(beta1_plot_path)
    plt.close()

    # plot g(μ) vs ln μ overlaid for every N
    plt.figure()
    for N in Ns:
        mus_n, gs_n = _sorted_series(arr, N)
        plt.plot(np.log(mus_n), gs_n, marker='o', linestyle='-',
                 label=f"N={int(N)}")
    plt.xlabel(r"$\ln\mu$")
    plt.ylabel(r"$g(\mu)$")
    plt.title(r"Running coupling $g(\mu)$ vs $\ln\mu$ for different N")
    plt.legend()
    plt.tight_layout()
    g_plot_path = results_dir / "g_vs_lnmu_N.png"
    plt.savefig(g_plot_path)
    plt.close()

    # print fitted β₁ and relative errors to stdout
    for N_val, beta1_fit in beta1_fit_rows:
        beta1_theory = 11 * N_val / 3
        rel_error = abs(beta1_fit - beta1_theory) / beta1_theory
        print(f"N={int(N_val)}: β1_fit={beta1_fit:.4f}, theory={beta1_theory:.4f}, "
              f"relative error={rel_error * 100:.2f}%")


if __name__ == "__main__":
    # Run the extraction only when executed as a script, not on import.
    main()